import os

import torch
import torchvision
from torch import nn

# Build a VGG16 with no pretrained weights (no weights argument is passed, so
# parameters are randomly initialised) and print its layer structure.
vgg = torchvision.models.vgg16()
print(vgg)

# Append a Linear(1000 -> 10) layer after the last layer of the classifier
# head. This assumes the stock 1000-way output (torchvision's default
# num_classes=1000), which the new layer maps to 10 outputs.
vgg.classifier.add_module("new_layer", nn.Linear(1000, 10))
print(vgg)


# Model saving

# torch.save does not create missing parent directories. Ensure ./models
# exists so neither save below fails on a fresh checkout.
os.makedirs('./models', exist_ok=True)

# Method 1: pickle the whole nn.Module (structure + parameters). Loading it
# later requires the defining classes to be importable. PyTorch >= 2.6 also
# defaults torch.load to weights_only=True, so weights_only=False is needed.
torch.save(vgg, './models/vgg_method1.pth')

# Method 2: save only the state_dict (parameter/buffer tensors), the approach
# recommended by the PyTorch docs. Restore it into a freshly built model with
# model.load_state_dict(torch.load(path)).
torch.save(vgg.state_dict(), './models/vgg_method2.pth')